Compare DMS to natural sequence evolution¶

In [1]:
# this cell is tagged parameters for papermill parameterization
# input data files
dms_summary_csv = None
growth_rates_csv = None
pango_consensus_seqs_json = None
# analysis options
starting_clades = None
dms_clade = None
n_random = None
exclude_clades = None
# output files
pango_dms_phenotypes_csv = None
pango_by_date_html = None
pango_affinity_vs_escape_html = None
pango_dms_vs_growth_regression_html = None
pango_dms_vs_growth_regression_by_domain_html = None
pango_dms_vs_growth_corr_html = None
pango_dms_vs_growth_corr_by_domain_html = None
exclude_clades_with_muts = None
In [2]:
# Parameters
# (this cell is injected by papermill with the actual parameter values)
starting_clades = ["BA.2", "BA.5", "XBB"]
dms_clade = "XBB.1.5"
dms_summary_csv = "results/summaries/summary.csv"
growth_rates_csv = "data/2023-09-18_Murrell_growth_estimates.csv"
pango_consensus_seqs_json = (
    "results/compare_natural/pango-consensus-sequences_summary.json"
)
pango_dms_phenotypes_csv = "results/compare_natural/pango_dms_phenotypes.csv"
pango_by_date_html = "results/compare_natural/pango_dms_phenotypes_by_date.html"
pango_affinity_vs_escape_html = "results/compare_natural/pango_affinity_vs_escape.html"
pango_dms_vs_growth_regression_html = (
    "results/compare_natural/pango_dms_vs_growth_regression.html"
)
pango_dms_vs_growth_regression_by_domain_html = (
    "results/compare_natural/pango_dms_vs_growth_regression_by_domain.html"
)
pango_dms_vs_growth_corr_html = "results/compare_natural/pango_dms_vs_growth_corr.html"
pango_dms_vs_growth_corr_by_domain_html = (
    "results/compare_natural/pango_dms_vs_growth_corr_by_domain.html"
)
n_random = 200
exclude_clades = []
exclude_clades_with_muts = []
In [3]:
import collections
import itertools
import json
import math
import re

import altair as alt

import numpy

import pandas as pd

import polyclonal.plot

import scipy.stats

import statsmodels.api

# disable Altair's default 5000-row limit so the large clade tables can be plotted
_ = alt.data_transformers.disable_max_rows()

Read Pango clades and mutations¶

In [4]:
# per-clade consensus-sequence summaries keyed by Pango clade name
with open(pango_consensus_seqs_json) as f:
    pango_clades = json.load(f)

def n_child_clades(c):
    """Count all descendant clades (direct and indirect) of Pango clade ``c``."""
    children = pango_clades[c]["children"]
    total = len(children)
    for child in children:
        total += n_child_clades(child)
    return total

def build_records(c, recs):
    """Build records of Pango clade information.

    Recursively appends clade `c` and all of its descendants to `recs`,
    a dict of parallel lists keyed by "clade", "n_child_clades", "date",
    and "muts_from_ref". Clades already in recs["clade"] are skipped, so
    overlapping starting clades are not double-counted.
    """
    if c in recs["clade"]:
        return
    recs["clade"].append(c)
    recs["n_child_clades"].append(n_child_clades(c))
    recs["date"].append(pango_clades[c]["designationDate"])
    # spike mutations relative to the reference, with the leading "S:" stripped
    recs["muts_from_ref"].append(
        [
            mut.split(":")[1]
            for field in ["aaSubstitutions", "aaDeletions"]
            for mut in pango_clades[c][field]
            if mut.startswith("S:")
        ]
    )
    for c_child in pango_clades[c]["children"]:
        build_records(c_child, recs)
        
# build a data frame of all clades descended from the starting clades
records = collections.defaultdict(list)
for starting_clade in starting_clades:
    build_records(starting_clade, records)

pango_df = pd.DataFrame(records).query("clade not in @exclude_clades")
# mutations of the DMS clade itself relative to the reference
dms_clade_mutations_from_ref = pango_df.set_index("clade").at[
    dms_clade, "muts_from_ref"
]

def mutations_from(muts, from_muts):
    """Get mutations separating a sequence from another sequence.

    Parameters
    ----------
    muts : list
        Mutations (relative to a shared reference) of the sequence of
        interest, as strings like "A123B" ("-" indicates a deletion).
    from_muts : list
        Mutations (relative to the same reference) of the other sequence.

    Returns
    -------
    list
        Mutations that convert the ``from_muts`` sequence into the ``muts``
        sequence, sorted by site.
    """
    new_muts = set(muts).symmetric_difference(from_muts)
    # raw string: "\d" / "\-" are invalid escapes in a normal string literal
    assert all(re.fullmatch(r"[A-Z\-]\d+[A-Z\-]", m) for m in new_muts)
    new_muts_d = collections.defaultdict(list)
    for m in new_muts:
        new_muts_d[int(m[1: -1])].append(m)
    new_muts_list = []
    for _, ms in sorted(new_muts_d.items()):
        if len(ms) == 1:
            # site differs from reference in only one of the two sequences
            m = ms[0]
            if m in muts:
                new_muts_list.append(m)
            else:
                # mutation only in from_muts: reverse it (e.g. B123A -> A123B)
                assert m in from_muts
                new_muts_list.append(m[-1] + m[1: -1] + m[0])
        else:
            # both sequences mutated at this site; combine into one mutation
            # going from the from_muts identity to the muts identity
            m, from_m = ms
            if m not in muts:
                from_m, m = m, from_m
            assert m in muts and from_m in from_muts
            new_muts_list.append(from_m[-1] + m[1: ])
    return new_muts_list

pango_df = (
    pango_df
    .assign(
        # re-express each clade's mutations relative to the DMS clade
        muts_from_dms_clade=lambda x: x["muts_from_ref"].apply(
            mutations_from, args=(dms_clade_mutations_from_ref,),
        ),
        date=lambda x: pd.to_datetime(x["date"]),
    )
    .drop(columns="muts_from_ref")
    .sort_values("date")
    .reset_index(drop=True)
)

# drop clades carrying any explicitly excluded mutation
for mut in exclude_clades_with_muts:
    pango_df = pango_df[pango_df["muts_from_dms_clade"].map(lambda ms: mut not in ms)]

pango_df
Out[4]:
clade n_child_clades date muts_from_dms_clade
0 BA.2 377 2021-12-07 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
1 BA.2.1 0 2022-02-25 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
2 BA.2.2 1 2022-02-25 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
3 BA.2.3 52 2022-02-25 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
4 BA.2.4 0 2022-03-25 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
... ... ... ... ...
1534 XBB.1.5.106 0 2023-09-17 [A623V]
1535 JG.3 0 2023-09-17 [Q52H, L455F, F456L, S704L]
1536 EG.5.1.9 0 2023-09-17 [Q52H, L176F, F456L]
1537 GK.4 0 2023-09-17 [L455F, F456L, A475V]
1538 GK.2.3 0 2023-09-17 [K356T, L455F, F456L, V511I]

1539 rows × 4 columns

Assign DMS phenotypes to Pango clades¶

First define function that assigns DMS phenotypes to mutations:

In [5]:
# read the DMS data
dms_summary = pd.read_csv(dms_summary_csv).rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
    }
)

# specify DMS phenotypes of interest
phenotypes = [
    "sera escape",
    "ACE2 affinity",
    "cell entry",
]
assert set(phenotypes).issubset(dms_summary.columns)

# colors used for each phenotype in the plots below
phenotype_colors = {
    "sera escape": "red",
    "ACE2 affinity": "blue",
    "cell entry": "purple",
}
assert set(phenotypes) == set(phenotype_colors)


# dict that maps site to wildtype in DMS
dms_wt = dms_summary.set_index("site")["wildtype"].to_dict()

# dict that maps site to region in DMS
site_to_region = dms_summary.set_index("site")["region"].to_dict()

def mut_dms(m, dms_data):
    """Get DMS phenotypes for a mutation.

    Parameters
    ----------
    m : str or NA
        Mutation like "A123B" relative to the DMS clade, or null for clades
        with no mutations.
    dms_data : dict
        Maps (site, wildtype, mutant) tuples to dicts of phenotype values
        (keys ordered as in `phenotypes`).

    Returns
    -------
    dict
        Keyed by `phenotypes` plus "is_RBD"; values are pd.NA when the
        mutation lacks a DMS measurement.
    """
    null_d = {k: pd.NA for k in phenotypes}
    if pd.isnull(m) or int(m[1: -1]) not in dms_wt:
        d = null_d
        d["is_RBD"] = pd.NA
    else:
        parent = m[0]
        site = int(m[1: -1])
        mut = m[-1]
        wt = dms_wt[site]
        if parent == wt:
            try:
                # copy so setting "is_RBD" below does not mutate the shared
                # entry stored in `dms_data`
                d = dict(dms_data[(site, parent, mut)])
            except KeyError:
                d = null_d
        elif mut == wt:
            # mutation reverts to the DMS wildtype: negate the measured effect
            try:
                d = {k: -v for (k, v) in dms_data[(site, mut, parent)].items()}
            except KeyError:
                d = null_d
        else:
            # neither identity is the DMS wildtype: take the difference of the
            # two mutations' effects from wildtype
            try:
                parent_d = dms_data[(site, wt, parent)]
                mut_d = dms_data[(site, wt, mut)]
                d = {p: mut_d[p] - parent_d[p] for p in phenotypes}
            except KeyError:
                d = null_d
        d["is_RBD"] = (site_to_region[site] == "RBD")
    assert list(d) == phenotypes + ["is_RBD"]
    return d

Now assign phenotypes to pango clades. We do this both using the actual DMS data and randomizing the DMS data among measured mutations:

In [6]:
def get_pango_dms_df(dms_data_dict):
    """Given dict mapping mutations to DMS data, get data frame of values for Pango clades.

    Returns one row per (clade, DMS phenotype) with the summed effects of the
    clade's mutations over the full spike, the RBD only, and the non-RBD
    region only. Mutations lacking DMS measurements contribute 0 and are
    listed in the "..._missing_data" column.
    """
    pango_dms_df = (
        pango_df
        # put one mutation in each row
        .explode("muts_from_dms_clade")
        .rename(columns={"muts_from_dms_clade": "mutation"})
        # to add multiple columns: https://stackoverflow.com/a/46814360
        .apply(
            lambda cols: pd.concat([cols, pd.Series(mut_dms(cols["mutation"], dms_data_dict))]),
            axis=1,
        )
        .melt(
            id_vars=["clade", "date", "n_child_clades", "mutation", "is_RBD"],
            value_vars=phenotypes,
            var_name="DMS_phenotype",
            value_name="mutation_effect",
        )
        .assign(
            # all of a clade's mutations as one "; "-delimited string
            muts_from_dms_clade=lambda x: x.groupby(["clade", "DMS_phenotype"])["mutation"].transform(
                lambda ms: "; ".join([m for m in ms if not pd.isnull(m)])
            ),
            # mutations present in the clade but lacking DMS measurements
            mutation_missing=lambda x: x["mutation"].where(
                x["mutation_effect"].isnull() & x["mutation"].notnull(),
                pd.NA,
            ),
            muts_from_dms_clade_missing_data=lambda x: (
                x.groupby(["clade", "DMS_phenotype"])["mutation_missing"]
                .transform(lambda ms: "; ".join([m for m in ms if not pd.isnull(m)]))
            ),
            # unmeasured mutations contribute an effect of 0
            mutation_effect=lambda x: x["mutation_effect"].fillna(0),
            is_RBD=lambda x: x["is_RBD"].fillna(False),
            mutation_effect_RBD=lambda x: x["mutation_effect"] * x["is_RBD"].astype(int),
            mutation_effect_nonRBD=lambda x: x["mutation_effect"] * (~x["is_RBD"]).astype(int),
        )
        .groupby(
            [
                "clade",
                "date",
                "n_child_clades",
                "muts_from_dms_clade",
                "muts_from_dms_clade_missing_data",
                "DMS_phenotype",
            ],
            as_index=False,
        )
        .aggregate(
            phenotype=pd.NamedAgg("mutation_effect", "sum"),
            phenotype_RBD_only=pd.NamedAgg("mutation_effect_RBD", "sum"),
            phenotype_nonRBD_only=pd.NamedAgg("mutation_effect_nonRBD", "sum"),
        )
        .rename(
            columns={
                "muts_from_dms_clade": f"muts_from_{dms_clade}",
                "muts_from_dms_clade_missing_data": f"muts_from_{dms_clade}_missing_data",
            },
        )
        .sort_values(["date", "DMS_phenotype"])
        .reset_index(drop=True)
    )

    # sanity checks: all clades retained, and regional effects sum to the total
    assert set(pango_df["clade"]) == set(pango_dms_df["clade"])
    assert numpy.allclose(
        pango_dms_df["phenotype"],
        pango_dms_df["phenotype_RBD_only"] + pango_dms_df["phenotype_nonRBD_only"]
    )

    return pango_dms_df

# First, get the actual DMS data mapped to phenotype
dms_data_dict_actual = (
    dms_summary
    .set_index(["site", "wildtype", "mutant"])
    [phenotypes]
    .to_dict(orient="index")
)
pango_dms_df = get_pango_dms_df(dms_data_dict_actual)
print(f"Saving Pango DMS phenotypes to {pango_dms_phenotypes_csv}")
pango_dms_df.to_csv(pango_dms_phenotypes_csv, float_format="%.4f", index=False)

# Now get the randomized DMS data mapped to phenotype
pango_dms_dfs_rand = []
numpy.random.seed(0)  # seed for reproducible randomizations
for irandom in range(1, n_random + 1):
    # shuffle each phenotype's column among mutations (note the whole column
    # is permuted, including any null entries)
    dms_summary_rand = dms_summary.copy()
    for phenotype in phenotypes:
        # assign calls the lambda immediately, so `phenotype` is bound per-iteration
        dms_summary_rand = dms_summary_rand.assign(
            **{phenotype: lambda x: numpy.random.permutation(x[phenotype].values)}
        )
    dms_data_dict_rand = (
        dms_summary_rand
        .set_index(["site", "wildtype", "mutant"])
        [phenotypes]
        .to_dict(orient="index")
    )
    pango_dms_dfs_rand.append(get_pango_dms_df(dms_data_dict_rand).assign(randomize=irandom))
# all randomizations concatenated
pango_dms_df_rand = pd.concat(pango_dms_dfs_rand)
Saving Pango DMS phenotypes to results/compare_natural/pango_dms_phenotypes.csv

Plot phenotypes of Pango clades¶

Plot phenotypes of Pango clades versus their designation dates:

In [7]:
# map from phenotype columns to spike-region labels used in the plots
region_cols = {
    "phenotype": "full spike",
    "phenotype_RBD_only": "RBD only",
    "phenotype_nonRBD_only": "non-RBD only",
}

# long-form frame: one row per clade / phenotype / spike region
pango_chart_df = (
    pango_dms_df
    .melt(
        id_vars=[c for c in pango_dms_df if c not in region_cols],
        value_vars=region_cols,
        var_name="spike_region",
        value_name="phenotype value",
    )
    .assign(
        spike_region=lambda x: x["spike_region"].map(region_cols),
    )
    .rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})
)

# columns cannot have "." in them for Altair
col_renames = {c: c.replace(".", "_") for c in pango_chart_df.columns if "." in c}
col_renames_rev = {v: k for (k, v) in col_renames.items()}
pango_chart_df = pango_chart_df.rename(columns=col_renames)

# selection that highlights a clade across all panels on mouseover
clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

base_pango_chart = (
    alt.Chart(pango_chart_df)
    .encode(
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev[c] if c in col_renames_rev else c)
            for c in pango_chart_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(60), alt.value(40)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
        color=alt.Color(
            "DMS_phenotype",
            legend=None,
            scale=alt.Scale(
                range=list(phenotype_colors.values()),
                domain=list(phenotype_colors.keys()),
            ),
        ),
    )
    .mark_circle(stroke="black")
    .properties(width=300, height=125)
)

# one row of panels per phenotype; x-axis labels only on the bottom row,
# column headers only on the top row
phenotype_pango_charts = []
for phenotype in phenotypes:
    first_row = (phenotype == phenotypes[0])
    last_row = (phenotype == phenotypes[-1])
    phenotype_pango_charts.append(
        base_pango_chart
        .transform_filter(alt.datum["DMS_phenotype"] == phenotype)
        .encode(
            x=alt.X(
                "date",
                title="designation date of clade" if last_row else None,
                axis=(
                    alt.Axis(titleFontSize=12, labelOverlap=True, format="%b-%Y", labelAngle=-90)
                    if last_row
                    else None
                ),
                scale=alt.Scale(nice=False, padding=3),
            ),
            y=alt.Y(
                "phenotype value",
                title=phenotype,
                axis=alt.Axis(titleFontSize=12),
                scale=alt.Scale(nice=False, padding=3),
            ),
            column=alt.Column(
                "spike_region",
                sort=list(region_cols),
                title=None,
                header=(
                    alt.Header(labelFontSize=12, labelFontStyle="bold", labelPadding=4)
                    if first_row
                    else None
                ),
                spacing=4,
            ),
        )
    )

pango_chart = (
    alt.vconcat(*phenotype_pango_charts, spacing=4)
    .configure_axis(grid=False)
    .add_params(clade_selection)
    .properties(
        title=alt.TitleParams(
            f"DMS predicted phenotypes of Pango clades descended from {', '.join(starting_clades)}",
            anchor="middle",
            fontSize=16,
            dy=-5,
        ),
    )
)

print(f"Saving chart to {pango_by_date_html}")
pango_chart.save(pango_by_date_html)

pango_chart
Saving chart to results/compare_natural/pango_dms_phenotypes_by_date.html
Out[7]:

Pango clade affinity versus escape scatter plot¶

In [8]:
# pivot so each phenotype is its own column, one row per clade
pango_scatter_df = (
    pango_dms_df
    .pivot_table(
        index=[
            c
            for c in pango_dms_df
            if c not in {"DMS_phenotype", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"}
        ],
        values="phenotype",
        columns="DMS_phenotype",
    )
    .reset_index()
    .rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})
    .rename(columns=col_renames)
)

pango_scatter_chart = (
    alt.Chart(pango_scatter_df)
    .encode(
        x=alt.X(
            "ACE2 affinity",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        y=alt.Y(
            "sera escape",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev[c] if c in col_renames_rev else c)
            for c in pango_scatter_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(100), alt.value(55)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
    )
    .mark_circle(stroke="red", color="black")
    .add_params(clade_selection)
    .configure_axis(grid=False)
    .properties(
        title=alt.TitleParams(
            [
                "DMS predicted ACE2 affinity vs serum escape",
                # list all starting clades, not the leftover loop variable
                # `starting_clade` (which held only the last one)
                f"for Pango clades descended from {', '.join(starting_clades)}"
            ],
            anchor="middle",
            fontSize=14,
            dy=-5,
        ),
    )
    .properties(width=300, height=300)
)

print(f"Saving chart to {pango_affinity_vs_escape_html}")
pango_scatter_chart.save(pango_affinity_vs_escape_html)

pango_scatter_chart
Saving chart to results/compare_natural/pango_affinity_vs_escape.html
Out[8]:

Correlate with clade growth¶

In [9]:
# read clade growth-rate estimates
growth_rates = pd.read_csv(growth_rates_csv).rename(
    columns={"pango": "clade", "seq_volume": "number sequences"}
)

# sanity check: every clade with a growth rate must be a known Pango clade
if (invalid_clades := set(growth_rates["clade"]) - set(pango_clades)):
    raise ValueError(f"Growth rates specified for {invalid_clades}")

# inner merges keep only clades with both DMS phenotypes and growth estimates
pango_dms_growth_df = pango_dms_df.merge(growth_rates, on="clade", validate="many_to_one")

pango_dms_growth_df_rand = pango_dms_df_rand.merge(growth_rates, on="clade", validate="many_to_one")

print(
    f"{growth_rates['clade'].nunique()} clades have growth rates estimates.\n"
    f"{pango_dms_df['clade'].nunique()} clades have DMS estimates.\n"
    f"{pango_dms_growth_df['clade'].nunique()} clades have growth and DMS estimates"
)

# correlation of growth rate (R) with each phenotype, per DMS phenotype
print("Simple correlations:")
display(
    pango_dms_growth_df
    .groupby("DMS_phenotype")
    [["R", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"]]
    .corr()
    [["R"]]
)
966 clades have growth rates estimates.
1539 clades have DMS estimates.
900 clades have growth and DMS estimates
Simple correlations:
R
DMS_phenotype
ACE2 affinity R 1.000000
phenotype -0.480301
phenotype_RBD_only -0.305674
phenotype_nonRBD_only -0.309355
cell entry R 1.000000
phenotype 0.797371
phenotype_RBD_only 0.820969
phenotype_nonRBD_only 0.448803
sera escape R 1.000000
phenotype 0.934447
phenotype_RBD_only 0.930579
phenotype_nonRBD_only 0.391210

Plot clade growth rate (R) versus designation date, with point sizes proportional to the log of the number of sequences in each clade:

In [10]:
# growth rate (R) versus designation date, point size ~ log number of sequences
(
    alt.Chart(pango_dms_growth_df)
    .encode(
        x="date",
        y="R",
        size=alt.Size("number sequences", scale=alt.Scale(type="log")),
        tooltip=pango_dms_growth_df.columns.tolist(),
    )
    .mark_circle(opacity=0.25, color="black")
)
Out[10]:

Now perform OLS, weighting clades by log number of sequences:

In [11]:
# pivot DMS data to get phenotypes
def pivot_for_ols_vars(df):
    """Pivot clade phenotype data into one column per (phenotype, region).

    `df` is a frame like `pango_dms_growth_df`; returns a frame with one row
    per clade and columns named like "sera escape (RBD)" that serve as the
    exogenous variables in the OLS regressions below.
    """
    ols_vars = (
        df
        .rename(
            columns={
                "phenotype": "full spike",
                "phenotype_RBD_only": "RBD",
                "phenotype_nonRBD_only": "non RBD",
            }
        )
        .assign(
            # group muts missing data from all phenotypes
            muts_from_DMS_clade_missing_data=lambda x: (
                x.groupby("clade")
                [f"muts_from_{dms_clade}_missing_data"]
                .transform(
                    lambda s: "; ".join(dict.fromkeys([m for ms in s.str.split("; ") for m in ms if m]))
                )
            ),
        )
        .rename(columns={f"muts_from_{dms_clade}": "muts_from_DMS_clade"})
        .pivot_table(
            index=[
                "clade",
                "R",
                "date",
                "muts_from_DMS_clade",
                "muts_from_DMS_clade_missing_data",
                "number sequences",
            ],
            columns="DMS_phenotype",
            values=["full spike", "RBD", "non RBD"],
        )
    )
    # flatten the two-level (region, phenotype) column index to strings
    assert all(len(c) == 2 for c in ols_vars.columns.values)
    ols_vars.columns = [f"{pheno} ({domain})" for domain, pheno in ols_vars.columns.values]
    return ols_vars.reset_index()

ols_vars = pivot_for_ols_vars(pango_dms_growth_df)

# https://www.einblick.ai/python-code-examples/ordinary-least-squares-regression-statsmodels/
for name, exog_vars, regression_chartfile, corr_chartfile in [
    (
        "full spike",
        [f"{c} (full spike)" for c in phenotypes],
        pango_dms_vs_growth_regression_html,
        pango_dms_vs_growth_corr_html
    ),
    (
        "separate RBD and non-RBD",
        [f"{c} ({d})" for d in ["RBD", "non RBD"] for c in phenotypes],
        pango_dms_vs_growth_regression_by_domain_html,
        pango_dms_vs_growth_corr_by_domain_html,
    ),
]:
    print(f"\n\nFitting for {name}:")
    ols_model = statsmodels.api.WLS(
        endog=ols_vars[["R"]],
        exog=statsmodels.api.add_constant(ols_vars[exog_vars]),
        # weight by log n sequences, so pass log**2
        weights=numpy.log(ols_vars["number sequences"])**2,
    )
    res_ols = ols_model.fit()
    display(res_ols.summary())

    fitted_df = ols_vars.assign(DMS_predicted_growth=res_ols.predict())

    plot_size = 180

    # selection to highlight a clade across panels on mouseover (re-bound
    # per regression so each set of charts is self-contained)
    clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

    # slider filtering clades by minimum log10 number of sequences
    n_sequences_init = int(10 * math.log10(fitted_df["number sequences"].min())) / 10
    n_sequences_slider = alt.param(
        value=n_sequences_init,
        bind=alt.binding_range(
            name="minimum log10 number sequences in clade",
            min=n_sequences_init,
            max=math.log10(fitted_df["number sequences"].max() / 10),
        ),
    )

    # date slider: https://stackoverflow.com/a/67941109
    select_date = alt.selection_interval(encodings=["x"])
    date_slider = (
        alt.Chart(fitted_df[["clade", "date"]].drop_duplicates())
        .mark_bar(color="black")
        .encode(
            x=alt.X(
                "date",
                title="zoom bar to select clades by designation date",
                axis=alt.Axis(format="%b-%Y"),
            ),
            y=alt.Y("count()", title=["number", "clades"]),
        )
        .properties(width=1.5 * plot_size, height=45)
        .add_params(select_date)
    )

    base_growth_chart = (
        alt.Chart(fitted_df)
        .transform_filter(
            alt.expr.log(alt.datum["number sequences"]) / math.log(10) >= n_sequences_slider
        )
        .transform_filter(select_date)
        .encode(
            size=alt.Size(
                "number sequences",
                scale=alt.Scale(
                    type="log",
                    nice=False,
                    range=[15, 250],
                ),
                legend=alt.Legend(symbolStrokeWidth=0, symbolFillColor="gray"),
            ),
            strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0.5)),
            strokeOpacity=alt.condition(clade_selection, alt.value(1), alt.value(0.5)),
            tooltip=[
                "clade",
                alt.Tooltip("R", title="growth rate (R)", format=".1f"),
                alt.Tooltip("DMS_predicted_growth", title="DMS predicted growth", format=".1f"),
                alt.Tooltip("number sequences", format=".2g"),
                alt.Tooltip("date", title="designation date"),
                alt.Tooltip("muts_from_DMS_clade", title=f"muts from {dms_clade}"),
                alt.Tooltip("muts_from_DMS_clade_missing_data", title="muts missing DMS data"),
                *[alt.Tooltip(v, format=".2f") for v in exog_vars],
            ],
        )
        .properties(width=plot_size, height=plot_size)
        .add_params(clade_selection, n_sequences_slider)
    )

    # build one panel per exogenous variable; phenotypes cycle so multi-domain
    # fits get one row of panels per domain
    growth_charts = []
    simple_corr_charts = []
    for i, (dms_pheno, pheno) in enumerate(zip(
        exog_vars,
        itertools.cycle(phenotypes)
    )):
        assert dms_pheno.startswith(pheno)
        base_pheno_chart = (
            base_growth_chart
            .encode(
                y=alt.Y(
                    "R",
                    title="actual clade growth rate (R)",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                    # only draw the y-axis on the first panel of each row
                    axis=None if i % len(phenotypes) else alt.Axis(),
                ),
            )
        )

        growth_charts.append(
            base_pheno_chart
            .encode(
                x=alt.X(
                    "DMS_predicted_growth",
                    title="DMS predicted clade growth",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    dms_pheno,
                    title=None,
                    legend=alt.Legend(
                        orient="top",
                        titleFontSize=12,
                        gradientLength=plot_size,
                        gradientThickness=10,
                        offset=5,
                        tickCount=3,
                    ),
                    scale=alt.Scale(
                        range=polyclonal.plot.color_gradient_hex("lightgray", phenotype_colors[pheno], 40),
                        nice=False,
                    ),
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.6)
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=(
                        f"coefficient: {res_ols.params[dms_pheno]:.1f} "
                        # https://stackoverflow.com/a/53966201
                        + f"\u00B1 {res_ols.bse[dms_pheno]:.1f}, "
                        + f"P: {res_ols.pvalues[dms_pheno]:.1g}"
                    ),
                    subtitleFontSize=11,
                ),
            )
        )

        pheno_r, pheno_p = scipy.stats.pearsonr(fitted_df["R"], fitted_df[dms_pheno])
        simple_corr_charts.append(
            base_pheno_chart
            .transform_calculate(color_phenotype=f"'{pheno}'")
            .encode(
                x=alt.X(
                    dms_pheno,
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    "color_phenotype:N",
                    scale=alt.Scale(
                        range=list(phenotype_colors.values()),
                        domain=list(phenotype_colors.keys()),
                    ),
                    legend=None,
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.3, color=phenotype_colors[pheno])
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=f"Pearson r: {pheno_r:.2f}",
                    subtitleFontSize=11,
                ),
            )
        )

    actual_r = math.sqrt(res_ols.rsquared)
    assert len(growth_charts) % len(phenotypes) == 0
    growth_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *growth_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    ).resolve_scale(color="independent")
                    for i in range(len(growth_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                f"Weighted linear regression of DMS phenotypes vs clade growth (Pearson r = {actual_r:.2f})",
                anchor="middle",
                fontSize=14,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )

    simple_corr_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *simple_corr_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    )
                    for i in range(len(simple_corr_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                "Simple correlations of DMS phenotypes vs clade growth",
                anchor="middle",
                fontSize=14,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )

    display(growth_chart)
    print(f"Saving to {regression_chartfile}")
    growth_chart.save(regression_chartfile)

    display(simple_corr_chart)
    print(f"Saving to {corr_chartfile}")
    simple_corr_chart.save(corr_chartfile)

    # fit randomized models and compute P-value based on R values
    print("Computing P-value from randomizations")
    rand_r = []
    for randomseed, rand_df in pango_dms_growth_df_rand.groupby("randomize"):
        rand_ols_vars = pivot_for_ols_vars(rand_df)
        rand_ols_model = statsmodels.api.WLS(
            endog=rand_ols_vars[["R"]],
            exog=statsmodels.api.add_constant(rand_ols_vars[exog_vars]),
            # weight by log n sequences, so pass log**2
            weights=numpy.log(rand_ols_vars["number sequences"])**2,
        )
        rand_res_ols = rand_ols_model.fit()
        rand_r.append(math.sqrt(rand_res_ols.rsquared))
    n_rand_ge = sum(r >= actual_r for r in rand_r)
    pval = f"= {n_rand_ge / len(rand_r)}" if n_rand_ge else f"< {1 / len(rand_r)}"

    rand_r_hist = (
        alt.Chart(pd.DataFrame({"r": rand_r}))
        .encode(
            x=alt.X(
                "r",
                title="Pearson r",
                bin=alt.BinParams(step=0.02, extent=(0, 1)),
                scale=alt.Scale(domain=(0, 1)),
                axis=alt.Axis(values=[0, 0.2, 0.4, 0.6, 0.8, 1]),
            ),
            y=alt.Y("count()", title="number of randomizations"),
        )
        .mark_bar(color="black", opacity=0.65, align="right")
        .properties(width=250, height=130)
    )

    actual_r_line = (
        alt.Chart(pd.DataFrame({"r": [actual_r]}))
        .encode(x="r")
        .mark_rule(size=2, color="red", strokeDash=[4, 2])
    )

    pval_chart = (
        (rand_r_hist + actual_r_line)
        .configure_axis(grid=False)
        .properties(
            title=alt.TitleParams(
                f"P {pval}",
                # stray literal tab removed from before the \u2265 sign
                subtitle=f"{n_rand_ge} of {len(rand_r)} randomizations \u2265 actual r of {actual_r:.2f}",
            ),
        )
    )

    display(pval_chart)

Fitting for full spike:
WLS Regression Results
Dep. Variable: R R-squared: 0.888
Model: WLS Adj. R-squared: 0.888
Method: Least Squares F-statistic: 2376.
Date: Fri, 22 Sep 2023 Prob (F-statistic): 0.00
Time: 14:47:34 Log-Likelihood: -3484.4
No. Observations: 900 AIC: 6977.
Df Residuals: 896 BIC: 6996.
Df Model: 3
Covariance Type: nonrobust
coef std err t P>|t| [0.025 0.975]
const 59.7536 0.730 81.888 0.000 58.321 61.186
sera escape (full spike) 23.9263 0.554 43.149 0.000 22.838 25.015
ACE2 affinity (full spike) 5.7993 1.283 4.520 0.000 3.281 8.317
cell entry (full spike) 14.5368 2.211 6.574 0.000 10.197 18.877
Omnibus: 8.301 Durbin-Watson: 0.746
Prob(Omnibus): 0.016 Jarque-Bera (JB): 8.505
Skew: 0.194 Prob(JB): 0.0142
Kurtosis: 3.275 Cond. No. 13.1


Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Saving to results/compare_natural/pango_dms_vs_growth_regression.html
Saving to results/compare_natural/pango_dms_vs_growth_corr.html
Computing P-value from randomizations

Fitting for separate RBD and non-RBD:
WLS Regression Results
Dep. Variable: R R-squared: 0.897
Model: WLS Adj. R-squared: 0.896
Method: Least Squares F-statistic: 1291.
Date: Fri, 22 Sep 2023 Prob (F-statistic): 0.00
Time: 14:48:07 Log-Likelihood: -3449.7
No. Observations: 900 AIC: 6913.
Df Residuals: 893 BIC: 6947.
Df Model: 6
Covariance Type: nonrobust
coef std err t P>|t| [0.025 0.975]
const 60.9830 0.727 83.852 0.000 59.556 62.410
sera escape (RBD) 28.7096 0.818 35.078 0.000 27.103 30.316
ACE2 affinity (RBD) 5.7018 1.388 4.107 0.000 2.977 8.427
cell entry (RBD) -16.8042 4.324 -3.886 0.000 -25.291 -8.317
sera escape (non RBD) 45.8783 4.781 9.596 0.000 36.495 55.261
ACE2 affinity (non RBD) 11.6336 2.009 5.790 0.000 7.690 15.577
cell entry (non RBD) 24.7628 3.033 8.165 0.000 18.810 30.715
Omnibus: 20.534 Durbin-Watson: 0.835
Prob(Omnibus): 0.000 Jarque-Bera (JB): 33.125
Skew: 0.179 Prob(JB): 6.41e-08
Kurtosis: 3.869 Cond. No. 28.5


Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Saving to results/compare_natural/pango_dms_vs_growth_regression_by_domain.html
Saving to results/compare_natural/pango_dms_vs_growth_corr_by_domain.html
Computing P-value from randomizations

Distributions of DMS mutation effects in clades with growth estimates versus all mutations¶

In [12]:
# count occurrences of each mutation among clades with growth estimates
muts_in_clades = collections.Counter(
    pango_dms_growth_df
    [f"muts_from_{dms_clade}"]
    .pipe(lambda s: s[s != ""])
    .str.split("; ")
    .explode()
)
print(f"There are {len(muts_in_clades)} mutations found in clades with growth estimates")

# long-form DMS effects of all measured mutations, labeled RBD vs non-RBD
all_muts_dms = (
    dms_summary
    .query("wildtype != mutant")
    .assign(mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"])
    .assign(region=lambda x: x["region"].where(x["region"] == "RBD", "non RBD"))
    .melt(
        id_vars=["mutation", "region"],
        value_vars=phenotypes,
        var_name="DMS_phenotype",
        value_name="phenotype",
    )
    .query("phenotype.notnull()") 
)

# stack all mutations against those in Pango clades (weighted by occurrence)
all_muts_dms = pd.concat(
    [
        all_muts_dms.assign(mutation_type="any", count=1),
        all_muts_dms.query("mutation in @muts_in_clades").assign(
            mutation_type="in Pango clade",
            count=lambda x: x["mutation"].map(muts_in_clades),
        ),
    ]
)

for pheno in phenotypes:
    
    # histogram of mutation effects, one row of panels per mutation type
    base_hist = (
        alt.Chart(
            all_muts_dms
            .query("DMS_phenotype == @pheno")
            .drop(columns=["DMS_phenotype", "mutation"])
        )
        .encode(
            x=alt.X("phenotype", bin=alt.BinParams(maxbins=50)),
            y=alt.Y("sum(count)", title="mutations"),
            color=alt.value(phenotype_colors[pheno]),
            row=alt.Row("mutation_type", title=None, spacing=5),
        )
        .properties(width=200, height=75, title=pheno)
        .mark_bar()
        .resolve_scale(y="independent")
    )
    display(base_hist)
There are 255 mutations found in clades with growth estimates